Introduction and Goals¶
Assuming that individuals with a higher risk of stroke can be accurately identified, negative outcomes could be eliminated or mitigated through lifestyle changes, additional testing, monitoring, etc. However, we still need to balance the costs of such additional services against their potential benefits.
We want to identify the highest possible number of potential strokes (i.e. maximize recall) while still maintaining high precision. This would suggest optimizing our models based on the F1 score for the stroke = 1 class. However, we also need to take into account the costs associated with each type of error:
- false positive: medium cost; additional tests and other services would be provided "unnecessarily" to individuals who have a low risk.
- false negative: very high cost; a missed stroke may require immediate hospitalization and might result in death.
So we can tolerate a much higher proportion of false positives than false negatives. The exact ratio would depend on a more in-depth cost analysis (which could be performed by healthcare and insurance providers).
Therefore, in our analysis we'll focus on maximizing the recall/class accuracy for stroke = 1 (the proportion of false positives should of course still be considered and minimized as a secondary target).
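This recall-weighted objective can be expressed as an F-beta score with beta > 1, which is what the tuning configurations later in this notebook use (e.g. beta = 2.5). A minimal sketch with scikit-learn, using toy labels rather than project data:

```python
from sklearn.metrics import fbeta_score, make_scorer

# beta > 1 weights recall more heavily than precision; beta = 2.5 matches
# one of the tuning configurations used later (XGBoostTuneCatFBeta_25).
stroke_scorer = make_scorer(fbeta_score, beta=2.5, pos_label=1)

# Toy example: 8 of 10 true strokes caught, at the price of 30 false alarms.
y_true = [1] * 10 + [0] * 90
y_pred = [1] * 8 + [0] * 2 + [1] * 30 + [0] * 60
score = fbeta_score(y_true, y_pred, beta=2.5, pos_label=1)
```

Despite a precision of only ~0.21 here, the F-2.5 score stays relatively high (~0.58) because recall dominates the metric.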
Core Assumptions:¶
- The cost of a false negative is much higher than the cost of a false positive
- Risk factors that might increase the likelihood of a stroke significantly also affect other health issues with a high mortality rate
- i.e. people who are at very high risk of stroke are likely to have a lower life expectancy and may die before they have a chance to have one. This would likely mean that our model assigns lower significance/importance to factors such as heart disease, obesity, or diabetes, while at the same time overestimating the effect of factors which are less correlated with other diseases.
- The dataset is likely not representative and some overfitting is unavoidable
EDA & Model¶
Our ideal baseline would be the "simple" algorithms used by doctors and healthcare providers based on risk factors such as:
- age
- blood pressure
- etc.
One important aspect to consider is that maximizing the overall performance of the model is not our goal: classifying a "high-risk" individual as "low-risk" carries a much bigger cost than doing the opposite.
Therefore we'll use two metrics when tuning our model:
- macro F1 score
- recall/accuracy for the minority `stroke` class
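Both metrics are straightforward to compute with scikit-learn; a small illustration on toy labels (the arrays here are placeholders, not project data):

```python
from sklearn.metrics import f1_score, recall_score

# Toy labels standing in for true/predicted stroke values.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]

macro_f1 = f1_score(y_true, y_pred, average="macro")       # unweighted mean over both classes
stroke_recall = recall_score(y_true, y_pred, pos_label=1)  # minority-class "accuracy"
```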
1.1 Analysis of individuals features and their distributions¶
The charts below show the distribution of all the features included in the dataset:
- Numerical features are displayed using KDE and boxen plots, with additional testing for normality.
- Value counts are shown for non-numerical features.
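The plots themselves are produced with seaborn (`kdeplot`/`boxenplot`); the normality-testing part can be sketched with SciPy alone. The synthetic `age` and `glucose` arrays below are stand-ins for the real columns:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
age = rng.normal(45, 15, 500)           # roughly symmetric stand-in
glucose = rng.lognormal(4.5, 0.4, 500)  # right-skewed, like avg_glucose_level

for name, values in [("age", age), ("avg_glucose_level", glucose)]:
    # D'Agostino-Pearson omnibus test of normality
    stat, p = stats.normaltest(values)
    verdict = "looks normal" if p > 0.05 else "not normal"
    print(f"{name}: p={p:.2g} ({verdict})")
```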
1.2 Relationships Between Features¶
Because the data types of the features vary, we had to use different methods to measure the strength and significance of each pairwise relationship:
- Chi-Squared Test: assesses independence between two categorical variables. Used for bool-bool pairs.
- Point-Biserial Correlation: measures correlation between a binary and a continuous variable. Used for bool-numerical pairs to account for the mixed data types.
- Spearman's Rank Correlation: assesses the monotonic relationship between two continuous variables. Used for numerical-numerical pairs (appropriate for non-normally distributed data).
Since the Chi-Squared test outputs an unbounded statistic which can't be directly compared to the point-biserial or Spearman coefficients, we have converted it to Cramér's V, which is normalized between 0 and 1. This was done to make the values in the matrix more uniform; however, we must note that Cramér's V and Spearman's correlation coefficients are fundamentally different statistics and generally can't be directly compared.
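The three measures (plus the Cramér's V conversion) can be sketched as follows; the synthetic arrays are illustrative stand-ins for the dataset's columns, and `cramers_v` is a hypothetical helper, not a SciPy function:

```python
import numpy as np
from scipy import stats

def cramers_v(x, y):
    """Cramér's V: chi-squared statistic normalized to the [0, 1] range."""
    cats_x, cats_y = np.unique(x), np.unique(y)
    table = np.array([[np.sum((x == a) & (y == b)) for b in cats_y]
                      for a in cats_x])
    chi2 = stats.chi2_contingency(table)[0]
    n, k = table.sum(), min(table.shape) - 1
    return np.sqrt(chi2 / (n * k))

rng = np.random.default_rng(42)
hypertension = rng.integers(0, 2, 200)
stroke = hypertension & rng.integers(0, 2, 200)    # dependent by construction
age = rng.normal(50, 12, 200)
bmi = 0.2 * age + rng.normal(25, 4, 200)           # monotonically related to age

v = cramers_v(hypertension, stroke)                # bool-bool pair
r_pb = stats.pointbiserialr(hypertension, age)[0]  # bool-numerical pair
rho = stats.spearmanr(age, bmi)[0]                 # numerical-numerical pair
```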
(Figure: Weight and Age)
1.3 Risk Factor Analysis¶
In this part we'll look into the relationship between specific risk factors which we would assume to be significantly related to the likelihood of having a stroke (both based on correlation and subject knowledge):
- age
- hypertension
- heart_disease
- avg_glucose_level
- bmi
- smoking_status
The KDE plots show the likelihood of having a stroke at a specific age for patients with each of the listed risk factors (the Y axis is relative to the full sample of individuals with the risk factor, not just the people who have the condition and had a stroke).
We can see that the number of risk factors on average increases until the age of ~60 and starts slightly decreasing afterwards. This is likely a case of survivorship bias, as most of the risk factors tend to have a negative effect on life expectancy.
This chart shows individual KDE density curves for each subgroup based on age (it can be interpreted similarly to a histogram).
Interestingly, the difference is most prominent below ~65; afterwards, the effect of having just 1 or 2 risk factors is much lower.
We can see that people who are not overweight, do not smoke, and do not have elevated glucose levels or heart issues have a much lower probability of having a stroke, as long as they are younger than 60.
Generally, most of the risk factors besides heart disease seem to have a similar effect below the age of 60; afterwards, diabetes (i.e. elevated glucose) and hypertension have a much larger effect.
1.5 PCA¶
We have attempted to use PCA to reduce the dimensionality of the dataset.
This might be necessary for datasets which include very high numbers of features. Since this specific dataset is very simple and includes a very low number of columns, this was done for informative/educational purposes only.
Additionally, we have included binary/categorical variables, which is generally not advisable in real-world cases.
While PCA can be used as a preprocessing step (and we have experimented with using it for simple logistic or SVM models), this is generally not necessary for simple datasets like this one.
Total feature count: 10
PCA was done using a scikit-learn pipeline which handles standardization for the numerical variables.
We can see that the dataset (not including the target variable) could effectively be reduced to 8 components (which preserve about 80% of the variance). Since this isn't much lower than the total number of variables, it's not particularly useful for ML or even visualization purposes.
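The pipeline and the 80%-variance threshold can be sketched as follows (the matrix `X` is a synthetic stand-in with 10 partially correlated features, mirroring the dataset's width):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 4))
# 10 observed features driven by 4 latent factors plus noise.
X = np.hstack([latent,
               latent @ rng.normal(size=(4, 6))
               + rng.normal(scale=0.5, size=(500, 6))])

pipe = make_pipeline(StandardScaler(), PCA())
pipe.fit(X)
cumvar = np.cumsum(pipe.named_steps["pca"].explained_variance_ratio_)
n_for_80pct = int(np.searchsorted(cumvar, 0.80) + 1)  # components needed for 80% variance
```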
2. ML Models¶
We have used various different models. Our process included these steps:
- Define separate configurations for each model based on the target variables/metrics used for tuning (see `src/model_config.py` and `shared/ml_config_core.py`). We have tested these models:
- XGBoost
- CatBoost
- LGBM
- SVM
- Random Forest
- Custom ensemble model (log + SVM + KNN with a soft voting classifier)
- Training and validation were performed using stratified K-folds (5 folds)
- Hyperparameter tuning was performed for each model. Because the dataset is heavily imbalanced, we have used various different target metrics:
- macro F1
- recall (only target class)
- F1 (only target class)
- various F-beta scores
Built-in class weight parameters were used for all the models besides the ensemble one, which uses SMOTE, ADASYN, standard oversampling, etc. The results for each individual model are stored separately in the .tuning_results folder.
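A stripped-down version of this tuning loop, using a RandomForest as a stand-in for the boosting models (the data and parameter grid below are illustrative only; the real runs use n_iter=250 and the model-specific grids defined in `src/model_config.py`):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import fbeta_score, make_scorer
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold

# Synthetic imbalanced data (~5% positives) standing in for the stroke dataset.
X, y = make_classification(n_samples=600, weights=[0.95], random_state=0)

search = RandomizedSearchCV(
    RandomForestClassifier(class_weight="balanced", random_state=0),
    param_distributions={"n_estimators": [50, 100], "max_depth": [3, 5, None]},
    n_iter=4,
    scoring=make_scorer(fbeta_score, beta=2.5, pos_label=1),
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
    random_state=0,
)
search.fit(X, y)
```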
| best_score | best_params | search_type | model_config_reference | |
|---|---|---|---|---|
| model_key | ||||
| XGBoostCatF1UndersampleAuto | 0.191422 | {'model__scale_pos_weight': 1, 'model__n_estimators': 250, 'model__min_child_weight': 1.5, 'model__max_depth': 6, 'model__learning_rate': 0.01, 'model__gamma': 0.3} | Random | XGBoostCatF1UndersampleAuto(model=<class 'xgboost.sklearn.XGBClassifier'>, supports_nan=True, param_grid={'model__learning_rate': [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1], 'model__max_depth': [4, 5, 6, 7, 10, 12, None], 'model__n_estimators': [50, 100, 150, 200, 250], 'model__min_child_weight': [0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 2.5, 3], 'model__gamma': [0, 0.05, 0.1, 0.3, 0.4], 'model__scale_pos_weight': [1, 5, 10, 20, 25, 30, 35, 40]}, builtin_params={'enable_categorical': True}, search_n_iter=250, balancing_config=UnderSamplingConfig(params={}), preprocessing=FunctionTransformer(func=<function preprocessing_for_xgboost.<locals>.convert_to_category at 0x7f6db42104c0>), tunning_func_target=make_scorer(f1_score, pos_label=1), best_params=None, ensemble_classifier=None) |
| XGBoostTuneCatFBeta_25 | 0.433492 | {'model__scale_pos_weight': 25, 'model__n_estimators': 250, 'model__min_child_weight': 1.5, 'model__max_depth': 4, 'model__learning_rate': 0.01, 'model__gamma': 0.1} | Random | XGBoostTuneCatFBeta_25(model=<class 'xgboost.sklearn.XGBClassifier'>, supports_nan=True, param_grid={'model__learning_rate': [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1], 'model__max_depth': [4, 5, 6, 7, 10, 12, None], 'model__n_estimators': [50, 100, 150, 200, 250], 'model__min_child_weight': [0.1, 0.25, 0.5, 0.75, 1, 1.5, 2, 2.5, 3], 'model__gamma': [0, 0.05, 0.1, 0.3, 0.4], 'model__scale_pos_weight': [1, 5, 10, 20, 25, 30, 35, 40]}, builtin_params={'enable_categorical': True}, search_n_iter=250, balancing_config=None, preprocessing=FunctionTransformer(func=<function preprocessing_for_xgboost.<locals>.convert_to_category at 0x7f6db42104c0>), tunning_func_target=make_scorer(fbeta_score, beta=2.5, pos_label=1), best_params=None, ensemble_classifier=None) |
Training times per configuration: LGBMForestBaseConfigTuneFBeta_25: 1.8 s; XGBoostCatF1UndersampleAuto: 0.6 s; Ensemble_Log_KNN_SVM_SMOTE: 29.7 s; XGBoostTuneCatFBeta_25: 0.6 s; XGBoostTuneCatFBeta_325: 0.6 s; XGBoostTuneCatFBeta_40: 0.6 s; XGBoostTuneCatFBeta_50: 0.6 s; XGBoostTuneRecall: 0.5 s; CatBoostBaseConfigTuneFBeta_15: 1.4 s; CatBoostBaseConfigTuneFBeta_20: 0.5 s; CatBoostBaseConfigTuneFBeta_25: 1.6 s; CatBoostBaseConfigTuneFBeta_325: 0.5 s; CatBoostBaseConfigTuneFBeta_40: 0.6 s; CatBoostBaseConfigTuneRecall: 0.5 s
Results¶
The table below shows the results for each configuration using the optimal parameters:
| accuracy | precision_macro | recall_macro | f1_macro | target_f1 | target_recall | target_precision | fbeta_1.5 | fbeta_2.5 | fbeta_4.0 | n_samples | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| XGBoostTuneCatFBeta_25 | 0.729 | 0.550 | 0.758 | 0.518 | 0.199 | 0.789 | 0.114 | 0.279 | 0.434 | 0.585 | 4908.0 |
| XGBoostCatF1UndersampleAuto | 0.712 | 0.548 | 0.754 | 0.508 | 0.191 | 0.799 | 0.109 | 0.270 | 0.426 | 0.582 | 4908.0 |
| XGBoostTuneRecall | 0.717 | 0.542 | 0.715 | 0.503 | 0.177 | 0.713 | 0.101 | 0.249 | 0.388 | 0.525 | 4908.0 |
| Ensemble_Log_KNN_SVM_SMOTE | 0.844 | 0.544 | 0.635 | 0.548 | 0.182 | 0.407 | 0.117 | 0.231 | 0.303 | 0.355 | 4908.0 |
| XGBoostTuneCatFBeta_325 | 0.897 | 0.561 | 0.619 | 0.576 | 0.207 | 0.316 | 0.153 | 0.238 | 0.276 | 0.297 | 4908.0 |
| XGBoostTuneCatFBeta_40 | 0.897 | 0.561 | 0.619 | 0.576 | 0.207 | 0.316 | 0.153 | 0.238 | 0.276 | 0.297 | 4908.0 |
| XGBoostTuneCatFBeta_50 | 0.897 | 0.561 | 0.619 | 0.576 | 0.207 | 0.316 | 0.153 | 0.238 | 0.276 | 0.297 | 4908.0 |
| CatBoostBaseConfigTuneFBeta_25 | 0.707 | 0.518 | 0.593 | 0.472 | 0.120 | 0.469 | 0.069 | 0.168 | 0.260 | 0.349 | 4908.0 |
| CatBoostBaseConfigTuneFBeta_15 | 0.710 | 0.518 | 0.592 | 0.473 | 0.120 | 0.464 | 0.069 | 0.168 | 0.259 | 0.347 | 4908.0 |
| CatBoostBaseConfigTuneFBeta_20 | 0.710 | 0.518 | 0.592 | 0.473 | 0.120 | 0.464 | 0.069 | 0.168 | 0.259 | 0.347 | 4908.0 |
| LGBMForestBaseConfigTuneFBeta_25 | 0.330 | 0.511 | 0.557 | 0.281 | 0.093 | 0.804 | 0.049 | 0.141 | 0.258 | 0.423 | 4908.0 |
| CatBoostBaseConfigTuneFBeta_325 | 0.360 | 0.510 | 0.554 | 0.299 | 0.092 | 0.766 | 0.049 | 0.140 | 0.254 | 0.412 | 4908.0 |
| CatBoostBaseConfigTuneFBeta_40 | 0.360 | 0.510 | 0.554 | 0.299 | 0.092 | 0.766 | 0.049 | 0.140 | 0.254 | 0.412 | 4908.0 |
| CatBoostBaseConfigTuneRecall | 0.360 | 0.510 | 0.554 | 0.299 | 0.092 | 0.766 | 0.049 | 0.140 | 0.254 | 0.412 | 4908.0 |
The behaviour of the precision-recall curves indicates very poor performance for all models (precision is very low at all thresholds). Additionally, all of the curves:
- are non-monotonic, i.e. they change direction on the Y axis several times as the threshold is changed, due to fluctuating true and false positive counts;
- show precision dropping quickly (even at very low thresholds) and varying significantly, due to the models' inability to consistently identify the sparse positive cases in this heavily imbalanced dataset.
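For reference, this is how such a curve is computed; on a synthetic imbalanced dataset with a plain LogisticRegression (an illustrative stand-in, not one of our tuned models) the same jagged shape typically appears:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data standing in for the stroke dataset.
X, y = make_classification(n_samples=1000, weights=[0.95], random_state=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=1)

probs = LogisticRegression(max_iter=1000).fit(X_tr, y_tr).predict_proba(X_te)[:, 1]
precision, recall, thresholds = precision_recall_curve(y_te, probs)
# By convention the curve ends at recall=0, precision=1; in between,
# precision can rise and fall as individual positives enter the selection.
```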
Selecting the "Best" Model¶
We have been able to get relatively comparable results with all the complex boosting models, and our ensemble model performs relatively similarly as long as some oversampling technique like SMOTE is used. With additional tuning it might provide effectively the same performance as XGBoost or CatBoost. However, training the (LogisticRegression + KNeighborsClassifier + SVC) ensemble is very slow, so it would still be much more practical to use a complex model which handles class balancing etc. directly.
As far as hyperparameter tuning goes, the only parameter that really matters is the class weight, which directly affects the recall/precision ratio (based on our selected F-beta value for scoring).
With that in mind we have selected TODO as our production model; while its overall performance is not ideal, it still provides reasonable performance relative to the assumptions outlined previously.
XX% recall relative to XX% precision means that for every person with stroke = 1 we will also select ~{N} individuals as "high risk".
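The relationship in the sentence above is simple arithmetic: at precision p, each flagged cohort contains one true case per 1/p people. Plugging in the target precision of the XGBoostTuneCatFBeta_25 row from the results table (~0.114):

```python
def flagged_per_true_stroke(precision):
    # 1/p people flagged per true positive,
    # i.e. (1 - p) / p false positives accompany each true case.
    return 1 / precision

precision = 0.114  # target_precision of XGBoostTuneCatFBeta_25 from the results table
print(round(flagged_per_true_stroke(precision), 1))  # ≈ 8.8 people flagged per true stroke
```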
Model Feature Importance and SHAP plots¶
We'll use SHAP values to further analyze the importance of each feature:
Probability Thresholds¶
An approach that might mitigate the precision/recall issue is to further split the risk group identified by our model into separate "Low", "Medium", and "High" risk categories, which would allow us to use resources more effectively by giving more focus to the individuals who have the highest risk:
The chart shows the performance of the model if only individuals with stroke probability > T are selected. Additionally, the overlay indicates the number of people whose predicted probability falls in a given range. The overlays can be used to select the most at-risk individuals based on the probability predicted for them.
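The banding step itself is a one-liner with pandas; the cut-offs below are hypothetical placeholders, since the real thresholds would come from the cost analysis:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(7)
predicted_prob = rng.uniform(0, 1, 10)  # stand-in for model output probabilities

risk = pd.cut(predicted_prob,
              bins=[0, 0.3, 0.6, 1.0],          # hypothetical cut-offs
              labels=["Low", "Medium", "High"],
              include_lowest=True)
```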
Conclusion¶
- We have tried multiple different ML models to predict the stroke column.
- While the best configurations identify roughly 79% of all stroke cases (target-class recall ~0.79), precision for the stroke = 1 class remains very low (~11%).
- This means that for every true stroke identified, several low-risk individuals are also flagged as "high risk", which reduces the cost-effectiveness of any follow-up services.
- On the positive side, high recall is exactly what our cost assumptions call for: missing a potential stroke (a false negative) is far more costly than an unnecessary follow-up.
Limitations and Suggestions for Future Improvements:¶
Business Case/Interpretation¶
- A deeper cost-based analysis should be performed (ideally based on data from specific insurance companies, government healthcare systems, etc.) to determine the acceptable precision/recall ratio. While the direct and indirect costs of an individual suffering a stroke might be high:
- It's not clear what real benefits identifying individual at-risk patients provides. If risk is mostly driven by lifestyle choices, additional treatment and monitoring would not be particularly useful if the patients are unwilling to alter their lifestyles.
- Potentially this model could be used in a consumer-facing app for self-identification purposes (i.e. to encourage changes in lifestyle choices).
Technical¶
- Tuning for 'log_loss' instead of a classification metric.
- Tweaking the probability threshold during hyperparameter tuning might be beneficial.
- Using AUCPR for tuning.
- Overfitting-related hyperparameters like 'early_stopping_rounds' can be utilized to cut model training early {TODO}